package org.zjvis.datascience.common.util.db;

import cn.hutool.db.Db;
import cn.hutool.db.DbUtil;
import cn.hutool.db.Entity;
import cn.hutool.db.dialect.impl.PostgresqlDialect;
import cn.hutool.db.ds.simple.SimpleDataSource;
import cn.hutool.db.meta.Column;
import cn.hutool.db.meta.Table;
import cn.hutool.db.sql.Wrapper;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.constant.DatabaseConstant;
import org.zjvis.datascience.common.dto.dataset.DatasetColumnDTO;
import org.zjvis.datascience.common.exception.BaseErrorCode;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.model.ColumnSchema;
import org.zjvis.datascience.common.model.TableSchema;
import org.zjvis.datascience.common.pool.BasePool;

import javax.sql.DataSource;
import java.sql.*;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * JDBC connection helper.
 *
 * <p>Static helpers open/close raw JDBC resources, read table metadata and build JDBC URLs.
 * Instance methods copy a source table into the target (PostgreSQL-dialect) database page by page,
 * running pages concurrently on the executor of the injected {@link BasePool}.
 *
 * @date 2021-11-25
 */
@Service
public class JDBCUtil {
    private final static Logger logger = LoggerFactory.getLogger("JDBCUtil");

    /** Maximum number of insert attempts per page when the target connection times out. */
    private static final int MAX_INSERT_ATTEMPTS = 5;

    /** Pause between two insert attempts, in milliseconds. */
    private static final long INSERT_RETRY_INTERVAL_MS = 1000L;

    /** Marker value of {@code lastValue} meaning "no incremental filter" (full import). */
    private static final String NO_LAST_VALUE = "-1";

    static {
        // Hutool global switch (turns off returning generated keys). It runs in a static initializer
        // so it already applies to the static helpers below, even before this bean is constructed.
        DbUtil.setReturnGeneratedKeyGlobal(false);
    }

    /** Identifier wrapper of the PostgreSQL dialect, used for all writes to the target database. */
    public static final Wrapper GP_WRAPPER = new PostgresqlDialect().getWrapper();
    /** MySQL function template applied to GEOMETRY columns when reading; {@code {}} is the column. */
    public static final String MYSQL_GET_GEO_TEXT_FUNCTION_TEMPLATE = "ST_AsText({})";

    @Autowired
    private BasePool pool;

    @Value("${data.importBatchSize: 5000}")
    private int importBatchSize;

    /**
     * Number of rows imported so far, keyed by target table. Page tasks of
     * {@link #writeToTable} that start after the key has been removed skip their work.
     */
    public static final ConcurrentHashMap<String, Integer> IMPORT_TASK_PROGRESS = new ConcurrentHashMap<>();

    /**
     * Closes the given JDBC resources (result set, then statement, then connection).
     * Any argument may be {@code null}; close failures are logged, not thrown.
     *
     * @param conn connection to close, may be {@code null}
     * @param stmt statement to close, may be {@code null}
     * @param rs   result set to close, may be {@code null}
     */
    public static void close(Connection conn, PreparedStatement stmt, ResultSet rs) {
        // PreparedStatement is a Statement; the cast selects the general overload (no recursion).
        close(conn, (Statement) stmt, rs);
    }

    /**
     * Closes the given JDBC resources (result set, then statement, then connection).
     * Any argument may be {@code null}; close failures are logged, not thrown.
     *
     * @param conn connection to close, may be {@code null}
     * @param stmt statement to close, may be {@code null}
     * @param rs   result set to close, may be {@code null}
     */
    public static void close(Connection conn, Statement stmt, ResultSet rs) {
        closeQuietly(rs);
        closeQuietly(stmt);
        closeQuietly(conn);
    }

    /** Closes one resource, logging instead of propagating any failure; {@code null} is ignored. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
        }
    }

    /**
     * Opens a JDBC connection through {@link DriverManager}.
     *
     * @param url      JDBC URL
     * @param user     user name
     * @param password password
     * @return a new connection; the caller must close it
     * @throws SQLException if the connection cannot be established
     */
    public static Connection getConnection(String url, String user, String password) throws SQLException {
        // Sets the login timeout (DriverManager-wide setting)
        DriverManager.setLoginTimeout(DatabaseConstant.LOGIN_TIMEOUT);
        return DriverManager.getConnection(url, user, password);
    }

    /**
     * Lists all database names visible through the connection.
     *
     * @param con connection to query; it is not closed
     * @param sql query whose first column is the database name
     * @return the names read; on SQL errors the error is logged and the rows read so far are returned
     */
    public static List<String> getAllDBs(Connection con, String sql) {
        List<String> res = new ArrayList<>();
        try (PreparedStatement ps = con.prepareStatement(sql);
             ResultSet rs = ps.executeQuery()) {
            while (rs.next()) {
                res.add(rs.getString(1));
            }
        } catch (SQLException e) {
            logger.error(e.getMessage(), e);
        }
        return res;
    }

    /**
     * Lists all table names visible through the connection.
     *
     * @param con connection to query; it is not closed
     * @param sql query whose first column is the table name
     * @return the table names
     * @throws SQLException if the query fails
     */
    public static List<String> getAllTables(Connection con, String sql) throws SQLException {
        List<String> res = new ArrayList<>();
        try (PreparedStatement ps = con.prepareStatement(sql);
             ResultSet rs = ps.executeQuery()) {
            while (rs.next()) {
                res.add(rs.getString(1));
            }
        }
        return res;
    }

    /**
     * Reads the column layout of a table from the connection's {@link DatabaseMetaData}.
     *
     * @param con    connection to use; it is not closed
     * @param tbName table name
     * @return {@code null} when {@code con} is {@code null}. Otherwise a schema with the table name set.
     *         Columns are filled unless a SQL error occurred (the error is logged).
     */
    public static TableSchema getTableSchema(Connection con, String tbName) {
        if (con == null) {
            return null;
        }
        TableSchema ts = new TableSchema();
        ts.setTableName(tbName);

        // NOTE(review): tbName is passed as a LIKE pattern and the schema pattern is "%". A '_' in the
        // name is therefore a wildcard, and same-named tables in other schemas also match. Consider
        // escaping the name with DatabaseMetaData#getSearchStringEscape.
        try (ResultSet colSet = con.getMetaData().getColumns(con.getCatalog(), "%", tbName, "%")) {
            List<ColumnSchema> columns = new ArrayList<>();
            while (colSet.next()) {
                ColumnSchema ds = new ColumnSchema();
                ds.setColumnName(colSet.getString("COLUMN_NAME"));
                ds.setColumnType(colSet.getString("TYPE_NAME"));
                Integer size = parseColumnSize(colSet.getString("COLUMN_SIZE"));
                if (size != null) {
                    ds.setColumnSize(size);
                }
                columns.add(ds);
            }
            ts.setColumns(columns);
        } catch (SQLException e) {
            logger.error(e.getMessage(), e);
        }
        return ts;
    }

    /** Parses a metadata COLUMN_SIZE value; returns {@code null} when it is absent, non-numeric or overflows int. */
    private static Integer parseColumnSize(String columnSize) {
        if (!StringUtils.isNumeric(columnSize)) {
            return null;
        }
        try {
            return Integer.valueOf(columnSize);
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Builds a table schema from a result set's metadata.
     *
     * @param rsm result-set metadata
     * @return {@code null} when {@code rsm} is {@code null}. Otherwise one column entry per result column,
     *         and the table name of the first column (when there is one).
     * @throws SQLException if the metadata cannot be read
     */
    public static TableSchema getTableSchema(ResultSetMetaData rsm) throws SQLException {
        if (rsm == null) {
            return null;
        }
        TableSchema ts = new TableSchema();
        int count = rsm.getColumnCount();
        List<ColumnSchema> columns = new ArrayList<>(count);
        if (count > 0) {
            ts.setTableName(rsm.getTableName(1));
        }

        for (int i = 1; i <= count; i++) {
            ColumnSchema ds = new ColumnSchema();
            ds.setColumnName(rsm.getColumnName(i));
            ds.setColumnType(rsm.getColumnTypeName(i));
            ds.setColumnSize(rsm.getColumnDisplaySize(i));
            columns.add(ds);
        }
        ts.setColumns(columns);
        return ts;
    }

    /**
     * Runs a query and returns its open result set.
     *
     * <p>The caller owns the result set. It should close it together with its statement, e.g.
     * {@code close(null, rs.getStatement(), rs)}.
     *
     * @param con connection to use; it is not closed
     * @param sql query to run
     * @return the result set, or {@code null} if a {@link SQLException} occurred (the error is logged)
     * @throws DataScienceException wrapping any other failure
     */
    public static ResultSet read(Connection con, String sql) {
        PreparedStatement ps = null;
        try {
            ps = con.prepareStatement(sql);
            return ps.executeQuery();
        } catch (SQLException e) {
            logger.error(e.getMessage(), e);
            // No result set escapes on failure, so nobody else can close the statement.
            closeQuietly(ps);
            return null;
        } catch (Exception e2) {
            closeQuietly(ps);
            throw new DataScienceException(BaseErrorCode.DATASET_QUERY_ERROR, e2);
        }
    }

    /**
     * Executes a single SQL statement.
     *
     * @param conn connection to use; it is not closed
     * @param sql  statement to execute
     * @return always {@code true}; failures are reported by exception
     * @throws SQLException if execution fails
     */
    public static boolean execute(Connection conn, String sql) throws SQLException {
        try (PreparedStatement ps = conn.prepareStatement(sql)) {
            ps.execute();
        }
        return true;
    }

    /**
     * Creates a non-pooled hutool {@link SimpleDataSource} with connectTimeout and autoReconnect
     * connection properties set.
     *
     * @param url      JDBC URL
     * @param user     user name
     * @param password password
     * @return the data source
     */
    public static DataSource getDataSource(String url, String user, String password) {
        SimpleDataSource ds = new SimpleDataSource(url, user, password);
        ds.addConnProps("connectTimeout", "6000000");
        ds.addConnProps("autoReconnect", "true");
        return ds;
    }

    /**
     * Builds the JDBC URL for a server.
     *
     * @param server       host
     * @param port         port
     * @param databaseName database name
     * @param databaseType database type; "oracle" (case-insensitive) selects an Oracle URL, anything else a MySQL URL
     * @param connectType  connection mode (Oracle only): "servicename" uses the service-name URL, anything else the SID URL
     * @param connectValue service name or SID (Oracle only); when not blank it replaces {@code databaseName}
     * @return the formatted URL
     */
    public static String getUrl(String server, Integer port, String databaseName, String databaseType, String connectType, String connectValue) {
        String template;
        // equalsIgnoreCase is null-safe here and, unlike toLowerCase(), independent of the default locale.
        if ("oracle".equalsIgnoreCase(databaseType)) {
            if (StringUtils.isNotBlank(connectValue)) {
                databaseName = connectValue;
            }
            template = "servicename".equalsIgnoreCase(connectType)
                    ? DatabaseConstant.ORACLE_JDBC_SERVICE_NAME_URL
                    : DatabaseConstant.ORACLE_JDBC_SID_URL;
        } else {
            template = DatabaseConstant.MYSQL_JDBC_URL;
        }
        return String.format(template, server, port, databaseName);
    }

    /**
     * Drops the given tables if they exist.
     *
     * <p>Table names are concatenated into the SQL unquoted, so they must come from trusted input.
     *
     * @param con    connection to use; it is not closed
     * @param tables table names to drop
     * @return {@code true} if every drop succeeded, {@code false} on the first SQL error (logged)
     */
    public static boolean cleanupTable(Connection con, Set<String> tables) {
        try (Statement ps = con.createStatement()) {
            for (String table : tables) {
                ps.execute("drop table if exists " + table + ";");
            }
        } catch (SQLException e) {
            logger.error(e.getMessage(), e);
            return false;
        }
        return true;
    }

    /** Builds a hutool query entity for the table that selects all of its columns. */
    public static Entity buildEntity(Table meta) {
        Entity entity = Entity.create(meta.getTableName());
        entity.setFieldNames(meta.getColumns().stream().map(Column::getName).collect(Collectors.toList()));
        return entity;
    }

    /** Maps each GEOMETRY column of the table to {@link #MYSQL_GET_GEO_TEXT_FUNCTION_TEMPLATE}. */
    public static Map<String, String> buildQueryFunctions(Table tbMeta) {
        Map<String, String> functions = new HashMap<>();
        for (Column column : tbMeta.getColumns()) {
            if ("GEOMETRY".equals(column.getTypeName())) {
                functions.put(column.getName(), MYSQL_GET_GEO_TEXT_FUNCTION_TEMPLATE);
            }
        }
        return functions;
    }

    /** Whether {@code lastValue} requests an incremental ({@code column > lastValue}) filter. */
    private static boolean hasIncrementalFilter(String lastValue) {
        return lastValue != null && !NO_LAST_VALUE.equals(lastValue);
    }

    /**
     * Reads one page of the source table for transfer into {@code targetTable}.
     *
     * @param fromDs          source data source
     * @param tbMeta          source table metadata
     * @param targetTable     target table name (prefixed with the default target schema for the handler)
     * @param page            page number
     * @param databaseType    source database type
     * @param preview         whether the data is only for preview
     * @param columnMessage   column descriptions passed to the row handler
     * @param lastValue       incremental lower bound; {@code "-1"} or {@code null} disables the filter
     * @param incrementColumn column compared against {@code lastValue}
     * @return the transformed rows of the page
     * @throws SQLException if the query fails
     */
    public List<Entity> pageForTransferData(DataSource fromDs, Table tbMeta, String targetTable, int page, String databaseType, boolean preview, List<DatasetColumnDTO> columnMessage, String lastValue, String incrementColumn) throws SQLException {
        Db db = DbUtil.use(fromDs);
        Wrapper wrapper = FunctionWrapper
                .build(db.getRunner().getDialect().getWrapper())
                .addFunctions(buildQueryFunctions(tbMeta));
        db.setWrapper(wrapper);

        Entity query = buildEntity(tbMeta);
        if (hasIncrementalFilter(lastValue)) {
            // Hutool where-condition syntax: a value starting with an operator yields "column > value".
            query.set(incrementColumn, " > " + lastValue);
        }

        // NOTE(review): pages are read concurrently by offset without an explicit ORDER BY; stable
        // paging depends on the source database returning rows in a consistent order.
        return db.page(query, page, importBatchSize,
                new TransformEntityListHandler(DatabaseConstant.GREEN_PLUM_DEFAULT_SCHEMA + "." + targetTable, databaseType, preview, columnMessage));
    }

    /**
     * Reads one page of data.
     *
     * @param ds           DataSource
     * @param tbMeta       table metadata
     * @param page         page number
     * @param pageSize     rows per page
     * @param databaseType database type; for "oracle" (case-insensitive) identifiers are quoted with {@code "}
     * @param preview      whether the data is only for preview
     * @return the rows of the page
     * @throws SQLException if the query fails
     */
    public static List<Entity> page(DataSource ds, Table tbMeta, int page, int pageSize, String databaseType, boolean preview) throws SQLException {
        Db db = DbUtil.use(ds);
        if ("oracle".equalsIgnoreCase(databaseType)) {
            // NOTE(review): this mutates the dialect's own wrapper; confirm the dialect instance is not
            // shared with non-Oracle users.
            db.getRunner().getDialect().getWrapper().setPreWrapQuote('\"');
            db.getRunner().getDialect().getWrapper().setSufWrapQuote('\"');
        }
        Wrapper wrapper = FunctionWrapper
                .build(db.getRunner().getDialect().getWrapper())
                .addFunctions(buildQueryFunctions(tbMeta));
        db.setWrapper(wrapper);

        return db.page(buildEntity(tbMeta), page, pageSize, new TransformEntityListHandler(null, databaseType, preview, null));
    }

    /**
     * Copies {@code dataSize} rows from the source table into {@code targetTable} of {@code toDs}.
     * Each page of {@code importBatchSize} rows is read and written by its own task on the pool's executor.
     * Progress is tracked in {@link #IMPORT_TASK_PROGRESS}, and the entry is removed when the import ends.
     *
     * @throws DataScienceException  if any page fails
     * @throws IllegalStateException if {@code importBatchSize} is not positive
     */
    public void writeToTable(DataSource fromDs, Table tbMeta, DataSource toDs, String targetTable, int dataSize, String sourceDbType, List<DatasetColumnDTO> columnMessage, String lastValue, String incrementColumn) throws Exception {
        if (importBatchSize <= 0) {
            // Otherwise the page count below becomes Integer.MAX_VALUE and the array allocation fails.
            throw new IllegalStateException("data.importBatchSize must be positive, got " + importBatchSize);
        }
        int pageCount = Math.max(0, (int) Math.ceil(dataSize * 1.0 / importBatchSize));
        CompletableFuture<?>[] futures = new CompletableFuture<?>[pageCount];
        IMPORT_TASK_PROGRESS.putIfAbsent(targetTable, 0);
        for (int i = 0; i < pageCount; i++) {
            final int pageNo = i;
            futures[i] = CompletableFuture.runAsync(() -> {
                // Skip the page if the progress entry is gone (import ended or was cancelled).
                if (!IMPORT_TASK_PROGRESS.containsKey(targetTable)) {
                    return;
                }
                try {
                    List<Entity> data = pageForTransferData(fromDs, tbMeta, targetTable, pageNo, sourceDbType, false, columnMessage, lastValue, incrementColumn);
                    if (data == null) {
                        logger.warn("task has been canceled: {}", targetTable);
                        return;
                    }
                    doWriteToTable(data, toDs);
                    IMPORT_TASK_PROGRESS.computeIfPresent(targetTable, (k, v) -> v + data.size());
                } catch (SQLException e) {
                    throw new CompletionException(e);
                } catch (NullPointerException e) {
                    // NOTE(review): an NPE is treated as cancellation rather than failure; confirm where it
                    // originates and replace this with an explicit check.
                    logger.warn("task has been canceled: {}", targetTable, e);
                }
            }, pool.getExecutor());
        }

        try {
            CompletableFuture.allOf(futures).join();
        } catch (CompletionException e) {
            throw new DataScienceException(BaseErrorCode.DATASET_IMPORT_GP_INSERT_ERROR, e);
        } finally {
            IMPORT_TASK_PROGRESS.remove(targetTable);
        }
    }

    /**
     * Inserts the rows into the target data source using the PostgreSQL identifier wrapper.
     * It retries up to {@value #MAX_INSERT_ATTEMPTS} times on {@link SQLTransientConnectionException}.
     *
     * @param data rows to insert
     * @param ds   target data source
     * @throws SQLException the last timeout if every attempt timed out, or any other insert failure
     */
    public static void doWriteToTable(List<Entity> data, DataSource ds) throws SQLException {
        logger.info("doWriteToTable size:{}", data.size());
        long start = System.currentTimeMillis();
        SQLTransientConnectionException lastError = null;
        for (int attempt = 1; attempt <= MAX_INSERT_ATTEMPTS; attempt++) {
            try {
                DbUtil.use(ds).setWrapper(GP_WRAPPER).insert(data);
                logger.info("doWriteToTable cost:{} ms", System.currentTimeMillis() - start);
                return;
            } catch (SQLTransientConnectionException e) {
                lastError = e;
                logger.warn("insert data timeout", e);
                if (attempt < MAX_INSERT_ATTEMPTS) {
                    try {
                        Thread.sleep(INSERT_RETRY_INTERVAL_MS);
                    } catch (InterruptedException ie) {
                        // Restore the flag and stop retrying; the failure is reported below.
                        Thread.currentThread().interrupt();
                        lastError.addSuppressed(ie);
                        break;
                    }
                }
            }
        }
        logger.info("doWriteToTable cost:{} ms", System.currentTimeMillis() - start);
        // Every attempt failed: surface the error instead of silently dropping the page.
        throw lastError;
    }

    /**
     * Counts the source rows to import, with the same incremental filter as {@link #pageForTransferData}.
     * When {@code lastValue} is {@code "-1"} or {@code null}, all rows are counted. Otherwise only rows
     * with {@code incrementColumn > lastValue} are counted.
     *
     * <p>The query uses MySQL backtick quoting for the table. {@code lastValue} is bound as a
     * parameter; {@code incrementColumn} is concatenated and must be a trusted identifier.
     *
     * @throws Exception if the query fails
     */
    public static int getDataSize(DataSource fromDs, String sourceTable, String incrementColumn, String lastValue) throws Exception {
        Db db = DbUtil.use(fromDs);
        // Double embedded backticks so the name cannot terminate the quoted identifier.
        String sql = "select count(1) from `" + StringUtils.replace(sourceTable, "`", "``") + "`";
        if (hasIncrementalFilter(lastValue)) {
            return db.queryNumber(sql + " where " + incrementColumn + " > ?", lastValue).intValue();
        }
        return db.queryNumber(sql).intValue();
    }
}
